{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the model configuration\n",
    "from dataclasses import dataclass\n",
    "\n",
    "@dataclass\n",
    "class TransformerConfig:\n",
    "    \"\"\"Hyperparameters passed to `Transformer` when the model is constructed.\n",
    "\n",
    "    NOTE(review): the per-field notes are inferred from the field names and from\n",
    "    the tensor shapes printed later in this notebook -- confirm against\n",
    "    tiny_transformer before relying on them.\n",
    "    \"\"\"\n",
    "    # presumably the longest token sequence the model accepts -- TODO confirm\n",
    "    block_size: int = 1024\n",
    "    # vocabulary size; in the demo run it equals the last dimension of the printed logits\n",
    "    vocab_size: int = 50304 \n",
    "    # presumably the number of stacked transformer layers -- TODO confirm\n",
    "    n_layer: int = 4\n",
    "    # presumably the number of attention heads per layer -- TODO confirm\n",
    "    n_head: int = 4\n",
    "    # embedding width; in the demo run it equals the last dimension of the printed tok_emb\n",
    "    n_embd: int = 768\n",
    "    # presumably the dropout probability -- TODO confirm\n",
    "    dropout: float = 0.0\n",
    "    # presumably toggles bias terms inside the model's layers -- TODO confirm\n",
    "    bias: bool = True \n",
    "\n",
    "# Much smaller values than the class defaults, used by the demo cells that follow.\n",
    "model_config = TransformerConfig(\n",
    "    block_size=12,\n",
    "    vocab_size=10,\n",
    "    n_layer=2,\n",
    "    n_head=4,\n",
    "    n_embd=16,\n",
    "    dropout=0.0,\n",
    "    bias=True,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of parameters: 0.02M\n"
     ]
    }
   ],
   "source": [
    "# Build the model from the configuration defined in the previous cell.\n",
    "\n",
    "import torch\n",
    "\n",
    "from tiny_transformer import Transformer\n",
    "\n",
    "# Seed torch's global RNG before the model is constructed so that a fresh\n",
    "# Restart & Run All repeats the same random draws: the token ids sampled with\n",
    "# torch.randint in the next cell, and presumably the weight initialisation and\n",
    "# sampling inside tiny_transformer, if it uses torch's global generator -- TODO confirm.\n",
    "torch.manual_seed(42)\n",
    "\n",
    "model = Transformer(model_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "idx torch.Size([4, 8])\n",
      "tok_emb torch.Size([4, 8, 16])\n",
      "x after wpe: torch.Size([4, 8, 16])\n",
      "enc_out: torch.Size([4, 8, 16])\n",
      "x after decoder: torch.Size([4, 8, 16])\n",
      "logits torch.Size([4, 1, 10])\n"
     ]
    }
   ],
   "source": [
    "# Forward pass on a random batch of token ids.\n",
    "\n",
    "import torch\n",
    "\n",
    "# Shape (4, 8): 4 rows of 8 token ids each (the generate() call later extends dim 1).\n",
    "# The upper bound is read from the config instead of a literal 10, so the ids\n",
    "# stay inside the vocabulary if vocab_size is changed.\n",
    "# NOTE(review): the lower bound of 1 means token id 0 is never drawn -- confirm this is intended.\n",
    "idx = torch.randint(1, model_config.vocab_size, (4, 8))\n",
    "# The second return value is ignored here.\n",
    "logits, _ = model(idx)\n",
    "# Sanity check: one score per vocabulary entry in the last dimension.\n",
    "assert logits.size(-1) == model_config.vocab_size, \"last logits dim should equal vocab_size\"\n",
    "print(\"logits\",logits.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "idx torch.Size([4, 8])\n",
      "tok_emb torch.Size([4, 8, 16])\n",
      "x after wpe: torch.Size([4, 8, 16])\n",
      "enc_out: torch.Size([4, 8, 16])\n",
      "x after decoder: torch.Size([4, 8, 16])\n",
      "idx torch.Size([4, 9])\n",
      "tok_emb torch.Size([4, 9, 16])\n",
      "x after wpe: torch.Size([4, 9, 16])\n",
      "enc_out: torch.Size([4, 9, 16])\n",
      "x after decoder: torch.Size([4, 9, 16])\n",
      "idx torch.Size([4, 10])\n",
      "tok_emb torch.Size([4, 10, 16])\n",
      "x after wpe: torch.Size([4, 10, 16])\n",
      "enc_out: torch.Size([4, 10, 16])\n",
      "x after decoder: torch.Size([4, 10, 16])\n",
      "generate result torch.Size([4, 11])\n"
     ]
    }
   ],
   "source": [
    "# Inference: ask the model for 3 additional tokens per row of `idx`.\n",
    "result = model.generate(idx, 3)\n",
    "print(\"generate result\", result.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 5, 7, 3, 3, 4, 4, 8, 1, 6, 7],\n",
       "        [1, 3, 3, 8, 8, 4, 8, 3, 6, 6, 5],\n",
       "        [5, 1, 3, 6, 7, 2, 8, 1, 6, 1, 6],\n",
       "        [3, 5, 7, 8, 6, 8, 2, 1, 2, 7, 9]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Generated result: display the tensor returned by model.generate\n",
    "# (presumably the prompt ids followed by the newly generated ids -- TODO confirm).\n",
    "result"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
